#include "yolox_detector.h"

#include <algorithm>
#include <memory>

namespace {

// Number of per-class scores stored in every prediction cell.
constexpr int kNumClasses = 80;
// IoU above which SoftNMS decays the confidence of an overlapping box.
constexpr float kNmsThresh = 0.45f;
// Minimum confidence a detection needs in order to be kept.
constexpr float kConfidenceThresh = 0.5f;

// Repacks an image read as 3-channel 8-bit pixels (cv::Vec3b) into a planar
// float buffer: every value of channel 0 first, then channel 1, then
// channel 2, each plane in row-major order. Pixel values are converted to
// float unchanged (no scaling or mean subtraction).
//
// The returned buffer holds img.total() * 3 floats. It is allocated with
// new[] and ownership passes to the caller, who must release it with delete[].
float *BlobFromImage(cv::Mat &img) {
  float *const planar = new float[img.total() * 3];
  const int rows = img.rows;
  const int cols = img.cols;
  const int plane_size = cols * rows;
  for (int y = 0; y < rows; ++y) {
    for (int x = 0; x < cols; ++x) {
      const cv::Vec3b &pixel = img.at<cv::Vec3b>(y, x);
      // Position of this pixel inside each channel plane.
      const int offset = y * cols + x;
      for (int c = 0; c < 3; ++c) {
        planar[c * plane_size + offset] = static_cast<float>(pixel[c]);
      }
    }
  }
  return planar;
}

inline float GetBoxIoU(const Object &a, const Object &b) {
  cv::Rect2f intersection = a.box & b.box;
  const float i = intersection.area();
  const float u = a.box.area() + b.box.area() - i;
  return (i / u);
}

void SoftNMS(std::vector<Object> *const objects, const float nms_thresh,
             const float confidence_thresh) {
  std::sort(objects->begin(), objects->end(),
            [](const Object &a, const Object &b) {
              return a.confidence > b.confidence;
            });
  std::vector<Object> reserved_objects;
  while (!objects->empty()) {
    const auto obj = objects->front();
    reserved_objects.push_back(obj);
    objects->erase(objects->begin());
    for (auto iter = objects->begin(); iter != objects->end();) {
      const float iou = GetBoxIoU(obj, *iter);
      if (iou > nms_thresh) {
        const float weight = std::exp(-(iou * iou) / 0.5f);
        iter->confidence *= weight;
      }
      if (iter->confidence < confidence_thresh) {
        iter = objects->erase(iter);
      } else {
        ++iter;
      }
    }
  }
  objects->swap(reserved_objects);
}

} // namespace

// Creates the TensorRT ONNX inference engine for `model_path_`, initializes
// it, and caches the dimensions of model input 0.
//
// Returns true on success and false when the engine fails to initialize.
bool YoloxDetector::Init() {
  tensorrt_onnx_.reset(new TensorrtOnnxInference(model_path_));
  if (!tensorrt_onnx_->Init()) {
    std::cout << "Failed to initialize tensorrt onnx inference engine!"
              << std::endl;
    // Must be `false`: any non-zero integer (such as -1) converts to `true`
    // and would report the failure as success.
    return false;
  }

  // The dims pair of input 0 is stored as (height, width).
  const auto model_input_dims = tensorrt_onnx_->GetModelInputDims(0);
  model_height_ = model_input_dims.first;
  model_width_ = model_input_dims.second;
  std::cout << "model input height: " << model_height_ << std::endl;
  std::cout << "model input width: " << model_width_ << std::endl;

  return true;
}

// Runs detection on `input_image` and stores the results in `*objects`.
//
// The image is resized to the model input size, repacked into a planar float
// buffer, and passed to the inference engine. The first output tensor is
// decoded into boxes expressed in `input_image` coordinates, which are then
// filtered with SoftNMS.
//
// Returns false (leaving `*objects` empty when it is non-null) if `objects`
// is null, `input_image` is empty, or inference produced no output; returns
// true otherwise.
bool YoloxDetector::Detect(cv::Mat &input_image, std::vector<Object> *objects) {
  if (objects == nullptr) {
    std::cout << "Detect called with a null output vector!" << std::endl;
    return false;
  }
  objects->clear();

  // An empty image cannot be resized and would make the scale factors below
  // divide by zero.
  if (input_image.empty()) {
    std::cout << "Detect called with an empty input image!" << std::endl;
    return false;
  }

  cv::Mat resize_image;
  cv::resize(input_image, resize_image, cv::Size(model_width_, model_height_));

  // BlobFromImage returns a new[]-allocated buffer; owning it here guarantees
  // delete[] on every exit path, including exceptions thrown further down.
  const std::unique_ptr<float[]> blob(BlobFromImage(resize_image));

  // Ratio of model input size to original image size, used by the decoder to
  // map boxes back to original image coordinates.
  const float width_scale = static_cast<float>(model_width_) / input_image.cols;
  const float height_scale =
      static_cast<float>(model_height_) / input_image.rows;

  std::vector<float *> input_data{blob.get()};
  std::vector<const float *> output_data;

  const auto start = std::chrono::steady_clock::now();
  tensorrt_onnx_->Infer(input_data, output_data);
  const auto end = std::chrono::steady_clock::now();
  const auto cost_time =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
          .count();
  std::cout << "Model inference cost time: " << cost_time << " ms" << std::endl;

  if (output_data.empty()) {
    std::cout << "Model inference produced no output!" << std::endl;
    return false;
  }

  DecodeYoloxOutputs(output_data.front(), width_scale, height_scale,
                     kConfidenceThresh, objects);
  std::cout << "objects size before nms: " << objects->size() << std::endl;
  SoftNMS(objects, kNmsThresh, kConfidenceThresh);
  std::cout << "objects size after nms: " << objects->size() << std::endl;

  return true;
}

// Decodes a raw output buffer into detections.
//
// `output` is read as three consecutive grids, one per stride (8, 16, 32).
// Each grid has (model_width_ / stride) * (model_height_ / stride) cells in
// row-major order, and each cell holds kNumClasses + 5 floats:
//   [0], [1]  box centre offsets relative to the cell's column / row
//   [2], [3]  log of box width / height, in stride units
//   [4]       objectness
//   [5 ...]   kNumClasses per-class scores
//
// A cell's confidence is its best class score multiplied by its objectness.
// Cells whose confidence exceeds `confidence_thresh` are appended to `*objs`
// with their box divided by `width_scale` / `height_scale`.
//
// `*objs` is cleared first. A null `output` yields an empty result.
void YoloxDetector::DecodeYoloxOutputs(const float *const output,
                                       const float width_scale,
                                       const float height_scale,
                                       const float confidence_thresh,
                                       std::vector<Object> *objs) {
  objs->clear();
  if (output == nullptr) {
    return;
  }

  constexpr int kStrides[] = {8, 16, 32};
  constexpr int kValuesPerCell = kNumClasses + 5;

  // Start of the grid that belongs to the stride currently being decoded.
  const float *level = output;
  for (const int stride : kStrides) {
    const int grid_width = model_width_ / stride;
    const int grid_height = model_height_ / stride;
    const int grid_size = grid_width * grid_height;
    for (int j = 0; j < grid_size; ++j) {
      const float *const cell = level + j * kValuesPerCell;
      const float *const class_scores = cell + 5;
      const float objectness = cell[4];
      const float *const best_class =
          std::max_element(class_scores, class_scores + kNumClasses);
      const int label = static_cast<int>(best_class - class_scores);
      const float confidence = (*best_class) * objectness;
      if (confidence > confidence_thresh) {
        const int row = j / grid_width;
        const int col = j % grid_width;
        // Centre and size in model-input pixels, then rescaled.
        const float x = (cell[0] + col) * stride / width_scale;
        const float y = (cell[1] + row) * stride / height_scale;
        const float w = std::exp(cell[2]) * stride / width_scale;
        const float h = std::exp(cell[3]) * stride / height_scale;

        Object obj;
        // Convert centre-based coordinates to a top-left anchored box.
        obj.box.x = x - w * 0.5f;
        obj.box.y = y - h * 0.5f;
        obj.box.width = w;
        obj.box.height = h;
        obj.label = label;
        obj.confidence = confidence;
        objs->push_back(std::move(obj));
      }
    }
    level += grid_size * kValuesPerCell;
  }
}
